package com.pomelo.chat.service.impl;

import com.alibaba.fastjson.JSON;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.pomelo.chat.config.ChatConfig;
import com.pomelo.chat.config.ChatGPTConfig;
import com.pomelo.chat.domain.Chat;
import com.pomelo.chat.domain.ChatVo;
import com.pomelo.chat.domain.Prompt;
import com.pomelo.chat.listener.ChatGPTSseEventSourceListener;
import com.pomelo.chat.service.ChatGPTService;
import com.pomelo.chat.service.ChatService;
import com.pomelo.chat.util.Common;
import com.pomelo.chat.util.Constant;
import com.pomelo.chat.util.Question;
import com.theokanning.openai.OpenAiService;
import com.theokanning.openai.completion.CompletionChoice;
import com.theokanning.openai.completion.CompletionRequest;
import com.theokanning.openai.completion.CompletionResult;
import com.unfbx.chatgpt.OpenAiClient;
import com.unfbx.chatgpt.OpenAiStreamClient;
import com.unfbx.chatgpt.entity.chat.ChatChoice;
import com.unfbx.chatgpt.entity.chat.ChatCompletion;
import com.unfbx.chatgpt.entity.chat.ChatCompletionResponse;
import com.unfbx.chatgpt.entity.chat.Message;
import com.unfbx.chatgpt.entity.common.Choice;
import com.unfbx.chatgpt.entity.completions.Completion;
import com.unfbx.chatgpt.entity.completions.CompletionResponse;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import okhttp3.OkHttpClient;
import org.jetbrains.annotations.NotNull;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import javax.annotation.Resource;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.text.ParseException;
import java.time.Duration;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.concurrent.TimeUnit;

@Service
@Slf4j
@Data
public class ChatGPTServiceImpl implements ChatGPTService {

    /**
     * Shared Jackson mapper used to re-read responses cached in Redis. ObjectMapper is expensive to
     * create and safe to share once configured, so it is not rebuilt per call.
     */
    private static final ObjectMapper MAPPER = new ObjectMapper();

    @Resource
    private ChatConfig chatConfig;

    @Resource
    private ChatService chatService;

    @Resource
    private RedisTemplate<String, List<Chat>> redisTemplate;

    /** Lazily created client for https://github.com/TheoKanning/openai-java; see {@link #initOpenAiService()}. */
    private OpenAiService openAiService;
    /** Model settings read from {@code chatConfig.getChatGPT()} when a client is first created. */
    private ChatGPTConfig chatGPTConfig;
    /** Client for https://github.com/Grt1228/chatgpt-java; only needed by the currently disabled {@link #askChatGPT}. */
    private OpenAiClient openAiClient;
    /** Lazily created streaming client; see {@link #initStreamClient()}. */
    private OpenAiStreamClient client;
    /**
     * Start time (epoch millis) of the most recent streaming request. Kept only for the Lombok-generated
     * accessors: this bean is shared by concurrent requests, so elapsed time is measured with a
     * request-local variable instead of this field.
     */
    private Long start_time;

    /**
     * Answers a question with the completion API of https://github.com/TheoKanning/openai-java.
     *
     * @param question the question entity
     * @param ip       the requesting client's IP address
     * @return the persisted chat record
     */
    @Override
    public Chat askGPT3(Question question, String ip) {
        initOpenAiService();
        return ask(question, ip);
    }

    /**
     * Creates the openai-java client (and loads {@link #chatGPTConfig}) on first use.
     * Synchronized so concurrent first calls cannot build several clients or observe a client whose
     * config has not been assigned yet.
     */
    private synchronized void initOpenAiService() {
        if (openAiService == null) {
            chatGPTConfig = chatConfig.getChatGPT();
            openAiService = new OpenAiService(chatConfig.getToken(), Duration.ofSeconds(chatConfig.getTimeOut()));
        }
    }

    /**
     * Intended to answer via the API of https://github.com/Grt1228/chatgpt-java.
     * The implementation is currently disabled and this method always returns {@code null}.
     *
     * @param question the question entity
     * @param ip       the requesting client's IP address
     * @return always {@code null} while the implementation below stays commented out
     */
    @Override
    public ChatVo askChatGPT(Question question, String ip) throws ParseException {
        // TODO(review): re-enable or remove; ask2(...) holds the implementation.
        //if (openAiClient == null) {
        //    openAiClient = new OpenAiClient(chatConfig.getToken(), chatConfig.getTimeOut(), chatConfig.getWriteTimeOut(), chatConfig.getReadTimeOut());
        //    chatGPTConfig = chatConfig.getChatGPT();
        //}
        //return ask2(question, ip);
        return null;
    }

    /**
     * Streams a chat-completion answer (model from configuration, e.g. GPT-3.5-TURBO) to the caller
     * over Server-Sent Events using the custom {@link ChatGPTSseEventSourceListener}.
     * <p>
     * The prompt consists of the conversation context cached in Redis followed by {@code question}.
     * When the emitter completes, the exchange is persisted through {@link ChatService} and appended
     * to the cached context.
     *
     * @param question the user's question
     * @param request  the current HTTP request; used to resolve the client IP and handed to the listener
     * @return an emitter (created with timeout {@code 0L}) that first receives a "连接成功" event and
     *         then the streamed answer
     * @throws IOException if the initial connection event cannot be sent
     */
    @Override
    public SseEmitter askChatGPT35(String question, HttpServletRequest request) throws IOException {
        log.info("question:{}", question);
        // Request-local start time: the start_time field is shared by all concurrent requests.
        long startTime = System.currentTimeMillis();
        this.setStart_time(startTime);
        // Resolve the IP now rather than in the completion callback, which runs after the async
        // request has finished, when the container may already have recycled the request object.
        String ip = Common.getIpAddress(request);
        OpenAiStreamClient streamClient = initStreamClient();

        SseEmitter sseEmitter = new SseEmitter(0L);
        ChatGPTSseEventSourceListener sourceListener = new ChatGPTSseEventSourceListener(sseEmitter);
        sourceListener.setQuestion(question);
        sourceListener.setRequest(request);
        sseEmitter.send(SseEmitter.event()
                .id("GaoYang")
                .name("连接成功")
                .data(LocalDateTime.now())
                .reconnectTime(Constant.RECONNECT_TIME));
        sseEmitter.onCompletion(() -> saveCompletedChat(question, ip, sourceListener, startTime));
        sseEmitter.onError(throwable -> {
            log.error("{}, uid#765431, on error#{}", LocalDateTime.now(), throwable.toString(), throwable);
            try {
                // Best effort: the emitter may already be unusable at this point.
                sseEmitter.send(SseEmitter.event().id("765431").name("发生异常！").data(throwable.getMessage()).reconnectTime(3000));
            } catch (IOException e) {
                log.error(e.getMessage(), e);
            }
        });

        ChatCompletion chatCompletion = ChatCompletion.builder()
                .model(chatGPTConfig.getModel())
                .messages(buildMessages(question))
                .build();
        streamClient.streamChatCompletion(chatCompletion, sourceListener);
        return sseEmitter;
    }

    /**
     * Returns the shared streaming client, creating it (and loading {@link #chatGPTConfig}) on first use.
     * Synchronized so concurrent first requests cannot build several clients or see a client whose
     * config has not been assigned yet.
     */
    private synchronized OpenAiStreamClient initStreamClient() {
        if (client == null) {
            chatGPTConfig = chatConfig.getChatGPT();
            OkHttpClient okHttpClient = new OkHttpClient
                    .Builder()
                    // Custom local HTTP proxy, only when a proxy port is configured.
                    .proxy(chatConfig.getProxyPort() == null ? null : new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", chatConfig.getProxyPort())))
                    .connectTimeout(chatConfig.getTimeOut(), TimeUnit.SECONDS)
                    .writeTimeout(chatConfig.getWriteTimeOut(), TimeUnit.SECONDS)
                    .readTimeout(chatConfig.getReadTimeOut(), TimeUnit.SECONDS)
                    .build();
            client = OpenAiStreamClient.builder()
                    .apiKey(Collections.singletonList(chatConfig.getToken()))
                    .apiHost(Constant.API_HOST)
                    .okHttpClient(okHttpClient)
                    .build();
        }
        return client;
    }

    /**
     * Builds the chat messages: each cached exchange as a user/assistant pair, followed by the new question.
     */
    private List<Message> buildMessages(String question) {
        List<Message> messages = new ArrayList<>();
        List<Chat> context = redisTemplate.opsForValue().get(Constant.REDIS_KEY);
        if (context != null) {
            for (Chat ch : context) {
                messages.add(Message.builder().content(ch.getPrompt()).role(Message.Role.USER).build());
                messages.add(Message.builder().content(ch.getAnswer()).role(Message.Role.ASSISTANT).build());
            }
        }
        messages.add(Message.builder().content(question).role(Message.Role.USER).build());
        return messages;
    }

    /**
     * Completion callback: persists the finished exchange and, on success, appends it to the Redis context.
     *
     * @param question  the question that was asked
     * @param ip        the client IP resolved when the request started
     * @param listener  the listener that collected the streamed answer
     * @param startTime request start time in epoch millis
     */
    private void saveCompletedChat(String question, String ip, ChatGPTSseEventSourceListener listener, long startTime) {
        Prompt prompt = new Prompt(question, ip, listener.getAllData().toString(), (int) (System.currentTimeMillis() - startTime));
        Chat chat = insert(listener.getChatCompletionResponse(), prompt);
        if (chat == null) {
            log.info("数据库插入失败");
            return;
        }
        log.info("数据库插入成功");
        appendToContext(chat);
        log.info("redis插入成功");
    }

    /**
     * Appends {@code chat} to the conversation context cached in Redis. The oldest entries are evicted
     * first so that, after adding, the list holds at most {@code chatConfig.getSize()} entries (at least
     * the new chat is always kept). The list is re-stored with a fresh expiry.
     */
    private void appendToContext(Chat chat) {
        List<Chat> cached = redisTemplate.opsForValue().get(Constant.REDIS_KEY);
        List<Chat> chats = new ArrayList<>();
        if (cached != null) {
            chats.addAll(cached);
        }
        // A loop with ">=" (not a single "==" check) also shrinks a list that is already over the
        // limit, e.g. after the configured size was lowered.
        while (!chats.isEmpty() && chats.size() >= chatConfig.getSize()) {
            chats.remove(0);
        }
        chats.add(chat);
        redisTemplate.opsForValue().set(Constant.REDIS_KEY, chats, Constant.REDIS_EXPIRE_TIME, TimeUnit.SECONDS);
    }

    /**
     * Builds a {@link Chat} record from a streamed response and persists it.
     *
     * @param response the response collected by the stream listener; may be {@code null} if streaming
     *                 failed before any response arrived
     * @param prompt   question, client IP, full answer text and elapsed time of the exchange
     * @return the persisted chat, or {@code null} if there was no usable response or the insert affected no rows
     */
    private Chat insert(ChatCompletionResponse response, @NotNull Prompt prompt) {
        if (response == null || response.getChoices() == null || response.getChoices().isEmpty()) {
            log.warn("No ChatCompletionResponse received, nothing persisted for question: {}", prompt.getQuestion());
            return null;
        }
        ChatChoice chatChoice = response.getChoices().get(0);
        if (chatChoice.getDelta() != null) {
            // Put the full collected answer into the response before it is serialized below
            // (presumably the listener's response only carries the last streamed chunk).
            chatChoice.getDelta().setContent(prompt.getData());
        }
        Chat chat = new Chat();
        chat.setUsername("name");
        chat.setUserKey(chatConfig.getToken());
        chat.setModel(response.getModel());
        chat.setPrompt(prompt.getQuestion());
        chat.setAnswer(prompt.getData());
        chat.setIsEnd("stop".equals(chatChoice.getFinishReason()));
        chat.setFinishReason(chatChoice.getFinishReason());
        chat.setResponseJson(JSON.toJSONString(response));
        // In an ongoing conversation, record the id of the first cached exchange; otherwise, or if
        // that id cannot be recovered, fall back to this response's own id.
        chat.setPromptId(response.getId());
        List<Chat> context = redisTemplate.opsForValue().get(Constant.REDIS_KEY);
        if (context != null && !context.isEmpty()) {
            try {
                chat.setPromptId(MAPPER.readValue(context.get(0).getResponseJson(), ChatCompletionResponse.class).getId());
            } catch (JsonProcessingException | IllegalArgumentException e) {
                // IllegalArgumentException: the cached record has no response JSON.
                log.error(e.getMessage(), e);
            }
        }
        chat.setRequestIpAddress(prompt.getIp());
        chat.setRequestUrl(Constant.REQUEST_URL);
        chat.setCreateTime(new Date());
        chat.setRequestTime(Long.valueOf(prompt.getRequestTime()));
        chat.setMaxTokens(chatGPTConfig.getMaxTokens());
        chat.setTemperature(chatGPTConfig.getTemperature());
        chat.setThreadName(Thread.currentThread().getName());
        return chatService.insert(chat) > 0 ? chat : null;
    }

    /**
     * Answers a question through the chatgpt-java completion API and persists the result.
     *
     * @param question the question entity
     * @param ip       the requesting client's IP address
     * @return a view of the created chat record (returned even if persisting failed)
     */
    private ChatVo ask2(Question question, String ip) throws ParseException {
        log.info(question.getQ());
        Completion completion = Completion.builder()
                .prompt(question.getQ())
                .model(chatGPTConfig.getModel())
                .temperature(chatGPTConfig.getTemperature())
                .maxTokens(chatGPTConfig.getMaxTokens())
                .topP(chatGPTConfig.getTopP())
                .echo(chatGPTConfig.getEcho())
                .build();
        long start = System.currentTimeMillis();
        CompletionResponse response = openAiClient.completions(completion);
        long end = System.currentTimeMillis();
        Choice choice = response.getChoices()[0];
        Chat chat = new Chat();
        chat.setUsername("GaoYang");
        chat.setUserKey(chatConfig.getToken());
        chat.setModel(response.getModel());
        chat.setPrompt(question.getQ());
        chat.setAnswer(choice.getText());
        chat.setIsEnd("stop".equals(choice.getFinishReason()));
        chat.setRequestIpAddress(ip);
        chat.setResponseJson(JSON.toJSONString(response));
        chat.setPromptId(response.getId());
        chat.setFinishReason(choice.getFinishReason());
        chat.setRequestUrl(Constant.REQUEST_URL);
        chat.setCreateTime(new Date());
        chat.setRequestTime(end - start);
        chat.setMaxTokens(chatGPTConfig.getMaxTokens());
        chat.setTemperature(chatGPTConfig.getTemperature());
        chat.setThreadName(Thread.currentThread().getName());
        // An affected-row count of 0 also means nothing was stored.
        if (chatService.insert(chat) <= 0) {
            log.error("插入数据库失败！");
        } else {
            log.info("插入数据成功！");
        }
        return new ChatVo(chat);
    }

    /**
     * Answers a question through the openai-java completion API and persists the result.
     *
     * @param question the question entity
     * @param ip       the requesting client's IP address
     * @return the created chat record (returned even if persisting failed)
     */
    private Chat ask(Question question, String ip) {
        log.info(question.getQ());
        CompletionRequest completionRequest = CompletionRequest.builder()
                .prompt(question.getQ())
                .model(chatGPTConfig.getModel())
                .temperature(chatGPTConfig.getTemperature())
                .maxTokens(chatGPTConfig.getMaxTokens())
                .topP(Double.valueOf(chatGPTConfig.getTopP()))
                .echo(chatGPTConfig.getEcho())
                .build();
        long start = System.currentTimeMillis();
        CompletionResult result = openAiService.createCompletion(completionRequest);
        long end = System.currentTimeMillis();
        CompletionChoice choice = result.getChoices().get(0);
        Chat chat = new Chat();
        chat.setUsername("GaoYang");
        chat.setUserKey(chatConfig.getToken());
        chat.setModel(result.getModel());
        chat.setPrompt(question.getQ());
        chat.setAnswer(choice.getText());
        chat.setIsEnd("stop".equals(choice.getFinish_reason()));
        chat.setRequestIpAddress(ip);
        chat.setResponseJson(JSON.toJSONString(result));
        chat.setPromptId(result.getId());
        chat.setFinishReason(choice.getFinish_reason());
        chat.setRequestUrl(Constant.REQUEST_URL);
        chat.setCreateTime(new Date());
        chat.setRequestTime(end - start);
        chat.setMaxTokens(chatGPTConfig.getMaxTokens());
        chat.setTemperature(chatGPTConfig.getTemperature());
        chat.setThreadName(Thread.currentThread().getName());
        // An affected-row count of 0 also means nothing was stored.
        if (chatService.insert(chat) <= 0) {
            log.error("插入数据库失败！");
        } else {
            log.info("插入数据成功！");
        }
        return chat;
    }


}
